{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbcdfea8-7241-46d7-a771-c0381a3e7063",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import re\n",
    "import math\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "from openai import OpenAI\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from datasets import load_dataset\n",
    "import chromadb\n",
    "from items import Item\n",
    "from testing import Tester"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98666e73-938e-469d-8987-e6e55ba5e034",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment\n",
    "# Load variables from a .env file; override=True is passed so that values from\n",
    "# the file take precedence over ones already set in the process environment.\n",
    "load_dotenv(override=True)\n",
    "# Re-export both secrets, falling back to a placeholder string when a variable is unset.\n",
    "# NOTE(review): the placeholder is not a usable key - replace it or define the variable in .env.\n",
    "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
    "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a25a5cf-8f6c-4b5d-ad98-fdd096f5adf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI client used by the pricing function further down; it is created with no\n",
    "# arguments, so it presumably picks up OPENAI_API_KEY from the environment set above.\n",
    "openai = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc696493-0b6f-48aa-9fa8-b1ae0ecaf3cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load in the test pickle file\n",
    "# NOTE: pickle.load can execute arbitrary code, so only open a test.pkl you trust.\n",
    "# Presumably the pickled objects are Item instances, which is why Item is imported above - confirm.\n",
    "with open('test.pkl', 'rb') as file:\n",
    "    test = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d38a06-0c0d-4e96-94d1-35ee183416ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_context(similars, prices):\n",
    "    \"\"\"Build the context preamble listing similar products and their prices.\n",
    "\n",
    "    Each entry of `similars` is paired with the entry of `prices` at the same\n",
    "    position (pairing stops at the shorter sequence) and rendered as one\n",
    "    paragraph after a fixed introductory sentence.\n",
    "    \"\"\"\n",
    "    intro = \"To provide some context, here are some other items that might be similar to the item you need to estimate.\\n\\n\"\n",
    "    entries = [\n",
    "        f\"Potentially related product:\\n{doc}\\nPrice is ${cost:.2f}\\n\\n\"\n",
    "        for doc, cost in zip(similars, prices)\n",
    "    ]\n",
    "    return intro + \"\".join(entries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61f203b7-63b6-48ed-869b-e393b5bfcad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for(item, similars, prices):\n",
    "    \"\"\"Assemble the chat messages used to price `item` with retrieved context.\n",
    "\n",
    "    The user turn is the similar-products context followed by the item's test\n",
    "    prompt, with the nearest-dollar wording and the trailing price cue removed.\n",
    "    That price cue is sent as the opening of the assistant turn instead.\n",
    "    \"\"\"\n",
    "    context = make_context(similars, prices)\n",
    "    question = item.test_prompt()\n",
    "    for fragment in (\" to the nearest dollar\", \"\\n\\nPrice is $\"):\n",
    "        question = question.replace(fragment, \"\")\n",
    "    user_prompt = context + \"And now the question for you:\\n\\n\" + question\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation. Price is always below $1000.\"\n",
    "    return [\n",
    "        dict(role=\"system\", content=system_message),\n",
    "        dict(role=\"user\", content=user_prompt),\n",
    "        dict(role=\"assistant\", content=\"Price is $\"),\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b26f405d-6e1f-4caa-b97f-1f62cd9d1ebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path passed to chromadb.PersistentClient below as the location of the vector store\n",
    "DB = \"products_vectorstore\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26a1104-cd11-4361-ab25-85fb576e0582",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the persistent Chroma store at DB and fetch the 'products' collection.\n",
    "# NOTE(review): get_or_create_collection presumably creates an empty collection when none exists,\n",
    "# in which case later queries would return no matches - confirm the store is populated first.\n",
    "client = chromadb.PersistentClient(path=DB)\n",
    "collection = client.get_or_create_collection('products')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e339760-96d8-4485-bec7-43fadcd30c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def description(item):\n",
    "    \"\"\"Return the item's prompt text without the question header and price suffix.\n",
    "\n",
    "    Every occurrence of the question header is removed, then everything from the\n",
    "    first price marker onwards is discarded.\n",
    "    \"\"\"\n",
    "    question_header = \"How much does this cost to the nearest dollar?\\n\\n\"\n",
    "    price_marker = \"\\n\\nPrice is $\"\n",
    "    body = item.prompt.replace(question_header, \"\")\n",
    "    before_price, _, _ = body.partition(price_marker)\n",
    "    return before_price"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f759bd2-7a7e-4c1a-80a0-e12470feca89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence-embedding model used by vector() to encode item descriptions.\n",
    "# NOTE(review): query vectors presumably must come from the same model that populated the collection - confirm.\n",
    "model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e44dbd25-fb95-4b6b-bbbb-8da5fc817105",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vector(item):\n",
    "    \"\"\"Encode the item's description with the sentence-transformer `model`.\n",
    "\n",
    "    The description is passed as a one-element list, so the result is whatever\n",
    "    `model.encode` returns for a single-text batch.\n",
    "    \"\"\"\n",
    "    texts = [description(item)]\n",
    "    return model.encode(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd5ee47-db5d-4263-b0d9-80d568c91341",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_similars(item, n_results=5):\n",
    "    \"\"\"Look up the stored products closest to `item` in the vector collection.\n",
    "\n",
    "    Args:\n",
    "        item: object accepted by `vector`; its description is embedded and used as the query.\n",
    "        n_results: how many matches to request. Defaults to 5, the value that\n",
    "            was previously hard-coded, so existing callers are unaffected.\n",
    "\n",
    "    Returns:\n",
    "        (documents, prices): the matched documents and the 'price' entry of each\n",
    "        match's metadata, in the order returned by the query.\n",
    "    \"\"\"\n",
    "    query_embeddings = vector(item).astype(float).tolist()\n",
    "    results = collection.query(query_embeddings=query_embeddings, n_results=n_results)\n",
    "    # Only one query vector is sent, so only the first row of each result field is read.\n",
    "    documents = results['documents'][0]\n",
    "    prices = [meta['price'] for meta in results['metadatas'][0]]\n",
    "    return documents, prices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7b9ff9-fd90-4627-bb17-7c2f7bbd21f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the raw prompt of one sample item from the test set\n",
    "print(test[1].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff1b2659-cc6b-47aa-a797-dd1cd3d1d6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve the similar documents and their prices for the sample item\n",
    "documents, prices = find_similars(test[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24756d4d-edac-41ce-bb80-c3b6f1cea7ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the context text built from the retrieved matches\n",
    "print(make_context(documents, prices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b81eca2-0b58-4fe8-9dd6-47f13ba5f8ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the full message list that is built for the chat model\n",
    "print(messages_for(test[1], documents, prices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d11f1c8d-7480-4d64-a274-b030d701f1b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_price(s):\n",
    "    \"\"\"Extract the first number found in a model reply as a float.\n",
    "\n",
    "    Dollar signs and commas are stripped before searching for a number.\n",
    "\n",
    "    Args:\n",
    "        s: reply text, or None when the model returned no content.\n",
    "\n",
    "    Returns:\n",
    "        The first number in the text as a float, or 0 when `s` is None, empty\n",
    "        or contains no digits.\n",
    "    \"\"\"\n",
    "    # A missing reply would otherwise fail with AttributeError on .replace().\n",
    "    if not s:\n",
    "        return 0\n",
    "    cleaned = s.replace('$', '').replace(',', '')\n",
    "    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", cleaned)\n",
    "    return float(match.group()) if match else 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06743833-c362-47f8-b02a-139be2cd52ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick manual check of get_price on a sample sentence\n",
    "get_price(\"The price for this is $99.99\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a919cf7d-b3d3-4968-8c96-54a0da0b0219",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The function for gpt-4o-mini\n",
    "\n",
    "def gpt_4o_mini_rag(item, model_name=\"gpt-4o-mini\"):\n",
    "    \"\"\"Estimate the price of `item` with an OpenAI chat model plus retrieved context.\n",
    "\n",
    "    Args:\n",
    "        item: the product to price; passed to `find_similars` and `messages_for`.\n",
    "        model_name: chat model to call. Defaults to \"gpt-4o-mini\" (the original\n",
    "            behaviour); pass another model id to compare models without\n",
    "            duplicating this function.\n",
    "\n",
    "    Returns:\n",
    "        The number that `get_price` parses from the model's reply.\n",
    "    \"\"\"\n",
    "    documents, prices = find_similars(item)\n",
    "    response = openai.chat.completions.create(\n",
    "        model=model_name,\n",
    "        messages=messages_for(item, documents, prices),\n",
    "        seed=42,\n",
    "        max_tokens=5,\n",
    "    )\n",
    "    # The reply content is optional in the API response; fall back to an empty\n",
    "    # string so get_price is never handed None.\n",
    "    reply = response.choices[0].message.content or \"\"\n",
    "    return get_price(reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b42e1b9-eaa0-4b45-a847-e8932367f596",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The function for gpt-4.1 (intentionally left commented out; uncomment it to compare against gpt-4o-mini)\n",
    "\n",
    "# def gpt_4_1_rag(item):\n",
    "#     documents, prices = find_similars(item)\n",
    "#     response = openai.chat.completions.create(\n",
    "#         model=\"gpt-4.1\", \n",
    "#         messages=messages_for(item, documents, prices),\n",
    "#         seed=42,\n",
    "#         max_tokens=5\n",
    "#     )\n",
    "#     reply = response.choices[0].message.content\n",
    "#     return get_price(reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e519e26-ff15-4425-90bb-bfbf55deb39b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try the RAG pricing function on the sample item (this calls the OpenAI API)\n",
    "gpt_4o_mini_rag(test[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082c6a5a-0f2a-4941-a465-ffb3137a2e8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gpt_4_1_rag(test[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce78741b-2966-41d2-9831-cbf8f8d176be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stored price of the sample item, for comparison with the estimate above\n",
    "test[1].price"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16d90455-ff7d-4f5f-8b8c-8e061263d1c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the Tester harness on the RAG pricing function over the test items\n",
    "Tester.test(gpt_4o_mini_rag, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26d5ddc6-baa6-4760-a430-05671847ac47",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
